import torch

from src.gfn.gfn import GFlowNet, LossTensor
from src.utils.trajectories import Trajectories

class DBGFlowNet(GFlowNet):
    """
    Detailed balance parameterization:
    flow model, forward policy, backward policy.

    In addition to the forward/backward models handled by ``GFlowNet``, this
    class owns a ``logF_model`` (scalar output) whose parameters are included
    in the optimizer.

    Default behavior:
    - No parameter sharing between forward/backward policy (``tied=False``)
    """
    def __init__(self,
                 env,
                 config,
                 forward_model,
                 backward_model,
                 logF_model,
                 tied: bool = False,
                 ):
        """
        Build the detailed-balance GFlowNet.

        Args:
            env: Environment, forwarded to ``GFlowNet.__init__``.
            config: Configuration mapping, forwarded to ``GFlowNet.__init__``;
                ``config["gfn"]["lr_schedule"]`` is used to build the scheduler.
            forward_model: Forward policy model, forwarded to the base class.
            backward_model: Backward policy model, forwarded to the base class.
            logF_model: Log state-flow model. It is moved to ``self.device``
                and must expose ``output_dim == 1``.
            tied: Forwarded to ``self._init_optimizer``; presumably controls
                parameter sharing between the policies (TODO confirm).

        Raises:
            ValueError: If ``logF_model.output_dim`` is not 1.
        """
        super().__init__(env, config, forward_model, backward_model)
        self.logF_model = logF_model.to(self.device)
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if self.logF_model.output_dim != 1:
            raise ValueError("LogF model must output a scalar.")
        # include_logF=True so the flow model's parameters are optimized too.
        self.optimizer = self._init_optimizer(tied, include_logF=True)
        self.scheduler = self._init_scheduler(config["gfn"]["lr_schedule"])

    def _compute_loss_precursors(self, trajs: Trajectories, head=None):
        """
        Compute the quantities the detailed-balance loss needs.

        Calls ``trajs.compute_logPF`` (with ``head``), ``trajs.compute_logPB``
        and ``trajs.compute_logF``. These presumably populate
        ``trajs.log_fullPF``, ``trajs.log_fullPB`` and ``trajs.logF``, which
        ``loss`` reads afterwards (TODO confirm against ``Trajectories``).
        """
        trajs.compute_logPF(self, head)
        trajs.compute_logPB(self)
        trajs.compute_logF(self)

    def loss(self, trajs: Trajectories, head=None) -> LossTensor:
        """
        Detailed balance loss.

        For each step ``i < trajs.length - 1``, the squared residual

            logF[:, i] + log_fullPF[:, i] - log_fullPB[:, i] - logF[:, i+1]

        is averaged over the batch. On the last step, ``logF[:, i+1]`` is
        replaced by ``log_rewards`` clipped from below at
        ``self.log_reward_clip_min`` (reward matching). The returned loss is
        the mean of these per-step averages.

        Dimension 1 of the per-step tensors is treated as the step axis, as
        the ``[:, i]`` indexing implies.

        Args:
            trajs: Batch of trajectories. ``_compute_loss_precursors`` is
                called on it first; the loss then reads its ``logF``,
                ``log_fullPF``, ``log_fullPB`` and ``log_rewards`` attributes.
            head: Forwarded to ``trajs.compute_logPF``.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: If ``trajs.length`` is 0, since there is no transition
                to score (the previous implementation silently returned NaN).
        """
        self._compute_loss_precursors(trajs, head)

        n_steps = trajs.length
        if n_steps == 0:
            raise ValueError(
                "Cannot compute detailed balance loss on zero-length trajectories."
            )

        # Left-hand side of the balance equation for every step at once:
        # logF[:, i] + log_fullPF[:, i] - log_fullPB[:, i].
        lhs = (
            trajs.logF[:, :n_steps]
            + trajs.log_fullPF[:, :n_steps]
            - trajs.log_fullPB[:, :n_steps]
        )

        # Intermediate steps (all but the last) are matched against the flow
        # of the next state, logF[:, i+1].
        intermediate_sq = (lhs[:, :-1] - trajs.logF[:, 1:n_steps]).pow(2)
        # Move the step axis first, then average over everything else so each
        # step gets its own batch mean (same reduction as ``.mean()`` per step).
        intermediate_means = (
            intermediate_sq.transpose(0, 1).flatten(start_dim=1).mean(dim=1)
        )

        # Last step is a terminal action: reward matching against the clipped
        # log-reward.
        clipped_log_rewards = trajs.log_rewards.clip(min=self.log_reward_clip_min)
        terminal_mean = (lhs[:, -1] - clipped_log_rewards).pow(2).mean()

        # Unlike the old float32 ``torch.zeros`` buffer, this keeps the dtype
        # of the inputs (no silent downcast of float64 losses).
        mean_losses = torch.cat([intermediate_means, terminal_mean.reshape(1)])
        return mean_losses.mean()
    
